"""Tests for acme.standalone."""
import socket
import threading
import unittest

import josepy as jose
try:
    import mock
except ImportError: # pragma: no cover
    from unittest import mock # type: ignore
import requests
from six.moves import http_client  # pylint: disable=import-error
from six.moves import socketserver  # type: ignore  # pylint: disable=import-error

from acme import challenges
from acme import crypto_util
from acme import errors

import test_util


class TLSServerTest(unittest.TestCase):
    """Tests for acme.standalone.TLSServer."""


    def test_bind(self):  # pylint: disable=no-self-use
        from acme.standalone import TLSServer
        server = TLSServer(
            ('', 0), socketserver.BaseRequestHandler, bind_and_activate=True)
        server.server_close()

    def test_ipv6(self):
        if socket.has_ipv6:
            from acme.standalone import TLSServer
            server = TLSServer(
                ('', 0), socketserver.BaseRequestHandler, bind_and_activate=True, ipv6=True)
            server.server_close()


class HTTP01ServerTest(unittest.TestCase):
    """Tests for acme.standalone.HTTP01Server."""


    def setUp(self):
        self.account_key = jose.JWK.load(
            test_util.load_vector('rsa1024_key.pem'))
        self.resources = set() # type: Set

        from acme.standalone import HTTP01Server
        self.server = HTTP01Server(('', 0), resources=self.resources)

        self.port = self.server.socket.getsockname()[1]
        self.thread = threading.Thread(target=self.server.serve_forever)
        self.thread.start()

    def tearDown(self):
        self.server.shutdown()
        self.thread.join()

    def test_index(self):
        response = requests.get(
            'http://localhost:{0}'.format(self.port), verify=False)
        self.assertEqual(
            response.text, 'ACME client standalone challenge solver')
        self.assertTrue(response.ok)

    def test_404(self):
        response = requests.get(
            'http://localhost:{0}/foo'.format(self.port), verify=False)
        self.assertEqual(response.status_code, http_client.NOT_FOUND)

    def _test_http01(self, add):
        chall = challenges.HTTP01(token=(b'x' * 16))
        response, validation = chall.response_and_validation(self.account_key)

        from acme.standalone import HTTP01RequestHandler
        resource = HTTP01RequestHandler.HTTP01Resource(
            chall=chall, response=response, validation=validation)
        if add:
            self.resources.add(resource)
        return resource.response.simple_verify(
            resource.chall, 'localhost', self.account_key.public_key(),
            port=self.port)

    def test_http01_found(self):
        self.assertTrue(self._test_http01(add=True))

    def test_http01_not_found(self):
        self.assertFalse(self._test_http01(add=False))

    def test_timely_shutdown(self):
        from acme.standalone import HTTP01Server
        server = HTTP01Server(('', 0), resources=set(), timeout=0.05)
        server_thread = threading.Thread(target=server.serve_forever)
        server_thread.start()

        client = socket.socket()
        client.connect(('localhost', server.socket.getsockname()[1]))

        stop_thread = threading.Thread(target=server.shutdown)
        stop_thread.start()
        server_thread.join(5.)

        is_hung = server_thread.is_alive()
        try:
            client.shutdown(socket.SHUT_RDWR)
        except: # pragma: no cover, pylint: disable=bare-except
            # may raise error because socket could already be closed
            pass

        self.assertFalse(is_hung, msg='Server shutdown should not be hung')


@unittest.skipIf(not challenges.TLSALPN01.is_supported(), "pyOpenSSL too old")
class TLSALPN01ServerTest(unittest.TestCase):
    """Test for acme.standalone.TLSALPN01Server."""

    def setUp(self):
        self.certs = {b'localhost': (
            test_util.load_pyopenssl_private_key('rsa2048_key.pem'),
            test_util.load_cert('rsa2048_cert.pem'),
        )}
        # Use different certificate for challenge.
        self.challenge_certs = {b'localhost': (
            test_util.load_pyopenssl_private_key('rsa4096_key.pem'),
            test_util.load_cert('rsa4096_cert.pem'),
        )}
        from acme.standalone import TLSALPN01Server
        self.server = TLSALPN01Server(("localhost", 0), certs=self.certs,
                challenge_certs=self.challenge_certs)
        # pylint: disable=no-member
        self.thread = threading.Thread(target=self.server.serve_forever)
        self.thread.start()

    def tearDown(self):
        self.server.shutdown()  # pylint: disable=no-member
        self.thread.join()

    # TODO: This is not implemented yet, see comments in standalone.py
    # def test_certs(self):
    #    host, port = self.server.socket.getsockname()[:2]
    #    cert = crypto_util.probe_sni(
    #        b'localhost', host=host, port=port, timeout=1)
    #    # Expect normal cert when connecting without ALPN.
    #    self.assertEqual(jose.ComparableX509(cert),
    #                     jose.ComparableX509(self.certs[b'localhost'][1]))

    def test_challenge_certs(self):
        host, port = self.server.socket.getsockname()[:2]
        cert = crypto_util.probe_sni(
            b'localhost', host=host, port=port, timeout=1,
            alpn_protocols=[b"acme-tls/1"])
        #  Expect challenge cert when connecting with ALPN.
        self.assertEqual(
                jose.ComparableX509(cert),
                jose.ComparableX509(self.challenge_certs[b'localhost'][1])
        )

    def test_bad_alpn(self):
        host, port = self.server.socket.getsockname()[:2]
        with self.assertRaises(errors.Error):
            crypto_util.probe_sni(
                b'localhost', host=host, port=port, timeout=1,
                alpn_protocols=[b"bad-alpn"])


class BaseDualNetworkedServersTest(unittest.TestCase):
    """Test for acme.standalone.BaseDualNetworkedServers."""


    class SingleProtocolServer(socketserver.TCPServer):
        """Server that only serves on a single protocol. FreeBSD has this behavior for AF_INET6."""
        def __init__(self, *args, **kwargs):
            ipv6 = kwargs.pop("ipv6", False)
            if ipv6:
                self.address_family = socket.AF_INET6
                kwargs["bind_and_activate"] = False
            else:
                self.address_family = socket.AF_INET
            socketserver.TCPServer.__init__(self, *args, **kwargs)
            if ipv6:
                # NB: On Windows, socket.IPPROTO_IPV6 constant may be missing.
                # We use the corresponding value (41) instead.
                level = getattr(socket, "IPPROTO_IPV6", 41)
                self.socket.setsockopt(level, socket.IPV6_V6ONLY, 1)
                try:
                    self.server_bind()
                    self.server_activate()
                except:
                    self.server_close()
                    raise

    @mock.patch("socket.socket.bind")
    def test_fail_to_bind(self, mock_bind):
        mock_bind.side_effect = socket.error
        from acme.standalone import BaseDualNetworkedServers
        self.assertRaises(socket.error, BaseDualNetworkedServers,
                          BaseDualNetworkedServersTest.SingleProtocolServer,
                          ('', 0),
                          socketserver.BaseRequestHandler)

    def test_ports_equal(self):
        from acme.standalone import BaseDualNetworkedServers
        servers = BaseDualNetworkedServers(
            BaseDualNetworkedServersTest.SingleProtocolServer,
            ('', 0),
            socketserver.BaseRequestHandler)
        socknames = servers.getsocknames()
        prev_port = None
        # assert ports are equal
        for sockname in socknames:
            port = sockname[1]
            if prev_port:
                self.assertEqual(prev_port, port)
            prev_port = port


class HTTP01DualNetworkedServersTest(unittest.TestCase):
    """Tests for acme.standalone.HTTP01DualNetworkedServers."""

    def setUp(self):
        self.account_key = jose.JWK.load(
            test_util.load_vector('rsa1024_key.pem'))
        self.resources = set() # type: Set

        from acme.standalone import HTTP01DualNetworkedServers
        self.servers = HTTP01DualNetworkedServers(('', 0), resources=self.resources)

        self.port = self.servers.getsocknames()[0][1]
        self.servers.serve_forever()

    def tearDown(self):
        self.servers.shutdown_and_server_close()

    def test_index(self):
        response = requests.get(
            'http://localhost:{0}'.format(self.port), verify=False)
        self.assertEqual(
            response.text, 'ACME client standalone challenge solver')
        self.assertTrue(response.ok)

    def test_404(self):
        response = requests.get(
            'http://localhost:{0}/foo'.format(self.port), verify=False)
        self.assertEqual(response.status_code, http_client.NOT_FOUND)

    def _test_http01(self, add):
        chall = challenges.HTTP01(token=(b'x' * 16))
        response, validation = chall.response_and_validation(self.account_key)

        from acme.standalone import HTTP01RequestHandler
        resource = HTTP01RequestHandler.HTTP01Resource(
            chall=chall, response=response, validation=validation)
        if add:
            self.resources.add(resource)
        return resource.response.simple_verify(
            resource.chall, 'localhost', self.account_key.public_key(),
            port=self.port)

    def test_http01_found(self):
        self.assertTrue(self._test_http01(add=True))

    def test_http01_not_found(self):
        self.assertFalse(self._test_http01(add=False))


if __name__ == "__main__":
    unittest.main()  # pragma: no cover
